// file:arch/x86/arch/bus/pci.c
// author: jiang xinpeng
// time: 2021.2.1
// copyright: (C) 2020-2050 by jiangxinpeng, All rights reserved.

#include <arch/x86.h>
#include <arch/pci.h>
#include <os/debug.h>
#include <lib/stddef.h>

pci_info_t pci;           // system pci info
static int bus_count = 0; // number of all bus
static int dev_count = 0; // number of all device

static uint8_t PciConfigReadB(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset);
static uint16_t PciConfigReadW(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset);
static uint32_t PciConfigReadD(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset);
static void PciConfigWriteB(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset, uint8_t data);
static void PciConfigWriteW(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset, uint16_t data);
static void PciConfigWriteD(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset, uint32_t data);

static void PciBrigeEnum(uint8_t bus, uint8_t dev, uint8_t fun);
static void PciFunEnum(uint8_t bus, uint8_t dev);

static int PciVendorCheck(uint8_t bus, uint8_t device, uint8_t function);
static int PciBusCheck(uint8_t bus);
static int PciDeviceCheck(uint8_t bus, uint8_t device);
static int PciFunctionCheck(uint8_t bus, uint8_t device, uint8_t function);
static int PciMuitiFunCheck(uint8_t bus, uint8_t device);

static pci_dev_t *PciDeviceAlloc();
static void PciDeviceDel(pci_dev_t *device);
static pci_bus_t *PciBusAlloc();
static void PciBusDel(pci_bus_t *bus);

static uint8_t PciGetClassCode(uint8_t bus, uint8_t dev, uint8_t fun);
static uint8_t PciGetSubClassCode(uint8_t bus, uint8_t dev, uint8_t fun);
static uint8_t PciGetSubBus(uint8_t bus, uint8_t device, uint8_t function);
static uint8_t PciGetHeader(uint8_t bus, uint8_t device, uint8_t function);
static uint16_t PciGetRevisionID(uint8_t bus, uint8_t device, uint8_t function);

// Read one byte of configuration space at `offset` of bus:device.function.
// The containing aligned dword is fetched and the wanted byte lane is
// shifted down into the low 8 bits.
static uint8_t PciConfigReadB(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset)
{
    const uint8_t aligned = (uint8_t)(offset & ~3u);
    const unsigned shift = (offset & 3u) * 8u;
    const uint32_t dword = PciConfigReadD(bus, device, function, aligned);

    return (uint8_t)(dword >> shift);
}

// Read one 16-bit word of configuration space at `offset` of bus:device.function.
// The containing aligned dword is fetched and shifted so the word starting at
// `offset` lands in the low 16 bits (offset is expected to be 2-byte aligned).
static uint16_t PciConfigReadW(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset)
{
    uint32_t dword = PciConfigReadD(bus, device, function, (uint8_t)(offset & ~3u));
    unsigned lane = offset & 3u;

    return (uint16_t)(dword >> (lane * 8u));
}

// Read one dword of configuration space using configuration mechanism #1:
// program PCI_CONFIG_ADDRESS with bus/device/function/register, then read
// PCI_CONFIG_DATA.
static uint32_t PciConfigReadD(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset)
{
    uint32_t address = 0x80000000; // bit 31: enable configuration access

    address |= (uint32_t)bus << PCI_CONFIG_ADDRESS_BUSNUM_OFF;
    address |= (uint32_t)device << PCI_CONFIG_ADDRESS_DEVICENUM_OFF;
    address |= (uint32_t)function << PCI_CONFIG_ADDRESS_FUNCTIONNUM_OFF;
    address |= (uint32_t)(offset & PCI_CONFIG_ADDRESS_REGOFF_MASK);

    Out32(PCI_CONFIG_ADDRESS, address);
    return In32(PCI_CONFIG_DATA);
}

// Write one byte of configuration space at `offset` of bus:device.function.
//
// This file accesses configuration space only as aligned dwords
// (PciConfigReadD / PciConfigWriteD). The byte is therefore merged into its
// containing dword with a read-modify-write:
//   1. read the aligned dword,
//   2. clear the target byte lane,
//   3. insert the new byte,
//   4. write the dword back.
// NOTE: the other three bytes are written back with the values just read.
// Registers with write-1-to-clear bits in those lanes (e.g. the status
// register) can be affected.
static void PciConfigWriteB(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset, uint8_t data)
{
    const uint32_t shift = (offset & 3u) * 8u;
    // byte address of the containing dword (not a dword index)
    const uint8_t aligned = (uint8_t)(offset & ~3u);
    uint32_t temp;

    temp = PciConfigReadD(bus, device, function, aligned);
    temp &= ~((uint32_t)0xFF << shift); // drop the old byte before inserting
    temp |= (uint32_t)data << shift;
    PciConfigWriteD(bus, device, function, aligned, temp);
}

// Write one 16-bit word of configuration space at `offset` of bus:device.function.
//
// As with PciConfigWriteB, the word is merged into its containing aligned
// dword with a read-modify-write: the 16-bit lane is cleared and then
// replaced. `offset` is expected to be 2-byte aligned. A word at
// offset % 4 == 3 would straddle two dwords; only its low byte is written.
// NOTE: the other half of the dword is written back with the value just read.
// Write-1-to-clear bits there can be affected.
static void PciConfigWriteW(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset, uint16_t data)
{
    const uint32_t shift = (offset & 3u) * 8u;
    // byte address of the containing dword (not a dword index)
    const uint8_t aligned = (uint8_t)(offset & ~3u);
    uint32_t temp;

    temp = PciConfigReadD(bus, device, function, aligned);
    temp &= ~((uint32_t)0xFFFF << shift); // drop the old word before inserting
    temp |= (uint32_t)data << shift;
    PciConfigWriteD(bus, device, function, aligned, temp);
}

// Write one dword of configuration space using configuration mechanism #1:
// program PCI_CONFIG_ADDRESS with bus/device/function/register, then write
// `data` to PCI_CONFIG_DATA.
static void PciConfigWriteD(uint8_t bus, uint8_t device, uint8_t function, uint8_t offset, uint32_t data)
{
    uint32_t address = (uint32_t)0x80000000; // bit 31: enable configuration access

    address |= (uint32_t)bus << PCI_CONFIG_ADDRESS_BUSNUM_OFF;
    address |= (uint32_t)device << PCI_CONFIG_ADDRESS_DEVICENUM_OFF;
    address |= (uint32_t)function << PCI_CONFIG_ADDRESS_FUNCTIONNUM_OFF;
    address |= (uint32_t)(offset & PCI_CONFIG_ADDRESS_REGOFF_MASK);

    Out32(PCI_CONFIG_ADDRESS, address);
    Out32(PCI_CONFIG_DATA, data);
}

// Vendor ID: the 16-bit word at configuration offset 0x00.
static uint16_t PciGetVendorID(uint8_t bus, uint8_t dev, uint8_t fun)
{
    const uint8_t vendor_reg = 0x00;

    return PciConfigReadW(bus, dev, fun, vendor_reg);
}

// Device ID: the 16-bit word at configuration offset 0x02.
static uint16_t PciGetDeviceID(uint8_t bus, uint8_t dev, uint8_t fun)
{
    const uint8_t device_reg = 0x02;

    return PciConfigReadW(bus, dev, fun, device_reg);
}

// Probe bus:device.function for presence.
// An absent function reads its vendor ID (config offset 0) as 0xFFFF.
// Returns 1 when the vendor ID is 0xFFFF (nothing there), 0 otherwise.
static int PciVendorCheck(uint8_t bus, uint8_t device, uint8_t function)
{
    return PciConfigReadW(bus, device, function, 0) == 0xffff;
}

// Return the first slot of pci.device[] whose flags are not PCI_DEVICE_USING.
// Returns NULL when all PCI_DEVICE_MAX slots are in use.
// The slot is not marked here; the caller sets PCI_DEVICE_USING once the
// entry is filled in.
static pci_dev_t *PciDeviceAlloc()
{
    int idx = 0;

    while (idx < PCI_DEVICE_MAX)
    {
        if (pci.device[idx].flags != PCI_DEVICE_USING)
            return pci.device + idx;
        ++idx;
    }
    return NULL;
}

// Release a device table slot by marking it PCI_DEVICE_INVALID.
// A NULL pointer is ignored.
static void PciDeviceDel(pci_dev_t *device)
{
    if (device != NULL)
        device->flags = PCI_DEVICE_INVALID;
}

static pci_bus_t *PciBusAlloc()
{
    int i;
    for (i = 0; i < PCI_BUS_MAX; i++)
    {
        if (pci.bus[i].flags != PCI_DEVICE_USING)
            return &pci.bus[i];
    }
    return NULL;
}

// Release a bus table slot by marking it PCI_BUS_INVALID.
// A NULL pointer is ignored.
static void PciBusFree(pci_bus_t *bus)
{
    if (bus != NULL)
        bus->flags = PCI_BUS_INVALID;
}

// Check whether anything answers on `bus`.
// Device slots 0..PCI_DEV_MAX-1 are probed in order until one responds.
// Returns 0 if at least one device is present, 1 if the bus is empty.
static int PciBusCheck(uint8_t bus)
{
    int device = 0;

    // PciDeviceCheck() returns non-zero for an empty slot
    while (device < PCI_DEV_MAX && PciDeviceCheck(bus, device))
        device++;

    return (device < PCI_DEV_MAX) ? 0 : 1;
}

// Check whether a device occupies slot `device` on `bus` by probing function 0.
// Returns 0 if present, 1 if the vendor ID reads 0xFFFF.
static int PciDeviceCheck(uint8_t bus, uint8_t device)
{
    return PciVendorCheck(bus, device, 0) ? 1 : 0;
}

// Check whether bus:device.function is implemented.
// Returns 0 if present, 1 if its vendor ID reads 0xFFFF.
static int PciFunctionCheck(uint8_t bus, uint8_t device, uint8_t function)
{
    const int absent = PciVendorCheck(bus, device, function);

    return absent ? 1 : 0;
}

// Check whether bus:device.function is a PCI-to-PCI bridge
// (class code 0x06, subclass 0x04).
// Presence is tested on function 0 of the device. Both codes are then read
// from `function`.
// Returns 0 for a bridge, 1 otherwise (including when no device is present).
static int PciBrigeCheck(uint8_t bus, uint8_t device, uint8_t function)
{
    uint8_t cls;
    uint8_t sub;

    if (PciDeviceCheck(bus, device))
        return 1; // empty slot

    cls = PciGetClassCode(bus, device, function);
    sub = PciGetSubClassCode(bus, device, function);

    return (cls == 0x06 && sub == 0x04) ? 0 : 1;
}

// Base class code: the byte at configuration offset 0x0B.
static uint8_t PciGetClassCode(uint8_t bus, uint8_t dev, uint8_t fun)
{
    const uint8_t class_reg = 0x0B;

    return PciConfigReadB(bus, dev, fun, class_reg);
}

// Subclass code: the byte at configuration offset 0x0A.
static uint8_t PciGetSubClassCode(uint8_t bus, uint8_t dev, uint8_t fun)
{
    const uint8_t subclass_reg = 0x0A;

    return PciConfigReadB(bus, dev, fun, subclass_reg);
}

// Enumerate every device slot (0..PCI_DEV_MAX-1) on `bus`.
// Each occupied slot is handed to PciFunEnum().
static void PciDeviceEnum(uint8_t bus)
{
    int slot;

    for (slot = 0; slot < PCI_DEV_MAX; slot++)
    {
        if (PciDeviceCheck(bus, slot))
            continue; // nothing in this slot

        PciFunEnum(bus, slot);
    }
}

// Register every implemented function of device `dev` on `bus`.
// Any function that is a PCI-to-PCI bridge is descended into.
//
// A device without the multi-function header bit only implements function 0;
// otherwise functions 0..PCI_FUN_MAX-1 are probed. Every registered function
// is checked for being a bridge, including function 0, which is where a
// single-function bridge lives.
static void PciFunEnum(uint8_t bus, uint8_t dev)
{
    int fun;
    // PciMuitiFunCheck() returns 0 when the device is multi-function
    const int fun_count = PciMuitiFunCheck(bus, dev) ? 1 : PCI_FUN_MAX;

    for (fun = 0; fun < fun_count; fun++)
    {
        // skip unimplemented functions (vendor ID 0xFFFF)
        if (PciFunctionCheck(bus, dev, fun))
            continue;

        // every function is registered as its own device
        PciDeviceAdd(bus, dev, fun);

        // a bridge exposes another bus: enumerate it as well
        if (!PciBrigeCheck(bus, dev, fun))
            PciBrigeEnum(bus, dev, fun);
    }
}

// Check the multi-function bit (PCI_HEADERTYPE_MULITIFUN) in the header type
// byte of function 0.
// Returns 0 if the device is multi-function, 1 if it is single-function.
static int PciMuitiFunCheck(uint8_t bus, uint8_t device)
{
    const uint8_t header_type = PciConfigReadB(bus, device, 0, PCI_HEADTYPE);

    return (header_type & PCI_HEADERTYPE_MULITIFUN) ? 0 : 1;
}

// Decide whether the host controller (bus 0, device 0, function 0) is
// multi-function, and log the result.
// Returns 0 if it is multi-function, 1 if single-function.
static int PciHostControllerCheck()
{
    const uint8_t header_type = PciGetHeader(0, 0, 0);

    if (!(header_type & PCI_HEADERTYPE_MULITIFUN))
    {
        KPrint("[pci] pci host is singlefun device\n");
        return 1;
    }
    KPrint("[pci] pci host is multifun device\n");
    return 0;
}

// Raw header type byte (offset PCI_HEADTYPE), including the multi-function bit.
static uint8_t PciGetHeader(uint8_t bus, uint8_t device, uint8_t function)
{
    return PciConfigReadB(bus, device, function, PCI_HEADTYPE);
}

// Enumerate the bus behind the PCI-to-PCI bridge at bus:dev.fun.
//
// The bridge's secondary bus number is read from configuration space. That
// bus is recorded with PciBusAdd() (father = `bus`, child and bus number =
// secondary bus) and then scanned with PciDeviceEnum().
static void PciBrigeEnum(uint8_t bus, uint8_t dev, uint8_t fun)
{
    uint8_t secondbus;

    if (PciBrigeCheck(bus, dev, fun))
        return; // not a PCI-to-PCI bridge

    secondbus = PciGetSubBus(bus, dev, fun);

    // A configured bridge's secondary bus is numbered above its primary bus.
    // A value of 0 (unconfigured bridge) or any number <= `bus` would make
    // PciDeviceEnum() rescan an already-visited bus and recurse without end.
    if (secondbus <= bus)
    {
        KPrint("[pci] bridge bus:%d dev:%d func:%d has invalid secondary bus %d, skipped\n", bus, dev, fun, secondbus);
        return;
    }

    // the last argument is the bus number being added: the secondary bus
    PciBusAdd(bus, secondbus, 0, secondbus);
    PciDeviceEnum(secondbus);
}

// Enumerate every PCI device reachable from the host controller(s).
//
// Single host controller: only bus PCI_BUS_FIRST is scanned directly.
// Multiple host controllers (bus 0 device 0 is multi-function): function N of
// that device corresponds to bus N. Only buses whose function is actually
// implemented are added and scanned. Buses behind bridges are reached
// recursively from PciFunEnum().
static void PciScan()
{
    int fun;

    if (pci.type != PCI_MUITIHOST)
    {
        PciBusAdd(-1, -1, 0, 0); // add first bus (only one bus)
        PciDeviceEnum(PCI_BUS_FIRST);
        return;
    }

    for (fun = 0; fun < PCI_FUN_MAX; fun++)
    {
        // no function here means no host controller, hence no bus `fun`
        if (PciFunctionCheck(0, 0, fun))
            continue;

        PciBusAdd(-1, -1, 0, fun); // bus number equals function number
        PciDeviceEnum(fun);
    }
}

// Secondary bus number of the bridge at bus:device.function
// (configuration offset 0x19).
// Returns 0 when the function is not a PCI-to-PCI bridge. A bridge that
// itself reports 0 is indistinguishable from that case.
static uint8_t PciGetSubBus(uint8_t bus, uint8_t device, uint8_t function)
{
    if (PciBrigeCheck(bus, device, function))
        return 0;

    return PciConfigReadB(bus, device, function, 0x19);
}

// Run the built-in self test of bus:dev.fun.
// The BIST register is configuration offset 0x0F, the top byte of the dword
// at 0x0C:
//   bit 7     BIST capable
//   bit 6     write 1 to start; the device clears it when the test finishes
//   bits 3:0  completion code, 0 = passed
// Returns 0 when the test passed or the device is not BIST capable.
// Returns 1 when the test failed or did not finish within the polling budget.
static int PciBIST(uint8_t bus, uint8_t dev, uint8_t fun)
{
    // Upper bound on status polls. This is NOT calibrated to wall-clock time;
    // the PCI spec lets BIST run up to 2 seconds. Raise it if slow devices
    // time out.
    const uint32_t max_polls = 1000000;
    uint32_t polls;
    uint32_t reg;
    uint8_t bist;

    bist = PciConfigReadB(bus, dev, fun, 0x0C + 3);
    if (!(bist & 0x80))
        return 0; // no BIST support: nothing to run

    // Set only the start bit. Read-modify-write the whole dword so bytes
    // 0x0C-0x0E (cache line size, latency timer, header type) keep their values.
    reg = PciConfigReadD(bus, dev, fun, 0x0C);
    PciConfigWriteD(bus, dev, fun, 0x0C, reg | ((uint32_t)0x40 << 24));

    // the completion code is only meaningful once bit 6 has been cleared
    for (polls = 0; polls < max_polls; polls++)
    {
        bist = PciConfigReadB(bus, dev, fun, 0x0C + 3);
        if (!(bist & 0x40))
            return (bist & 0x0F) ? 1 : 0;
    }
    return 1; // still running: treat as failure
}

// Header layout type: the header type byte with the multi-function bit (0x80)
// stripped.
static uint8_t PciGetDeviceType(uint8_t bus, uint8_t device, uint8_t function)
{
    return (uint8_t)(PciGetHeader(bus, device, function) & 0x7F);
}

// Programming interface byte at configuration offset 0x09.
static uint8_t PciGetProIF(uint8_t bus, uint8_t device, uint8_t function)
{
    const uint8_t prog_if_reg = 0x09;

    return PciConfigReadB(bus, device, function, prog_if_reg);
}

// Interrupt byte at configuration offset PCI_IRQ, widened to 32 bits.
static uint32_t PciGetIrq(uint8_t bus, uint8_t device, uint8_t function)
{
    const uint8_t line = PciConfigReadB(bus, device, function, PCI_IRQ);

    return (uint32_t)line;
}

// Revision ID of bus:device.function.
// The register is a single byte. Assuming PCI_REVISION_ID is the standard
// offset 0x08, a word read would also pull the programming-interface byte
// (0x09) into bits 15:8, so only the byte is read.
// The uint16_t return type is kept for compatibility with existing callers.
static uint16_t PciGetRevisionID(uint8_t bus, uint8_t device, uint8_t function)
{
    return PciConfigReadB(bus, device, function, PCI_REVISION_ID);
}

// Fill the location and identification fields of *device for bus:dev.fun.
// The identification fields are read from configuration space in a fixed
// order. The flags and BAR fields are not touched here.
static void PciDeviceInit(pci_dev_t *device, uint8_t bus, uint8_t dev, uint8_t fun)
{
    pci_dev_t *d = device;

    /* where the function lives */
    d->bus = bus;
    d->dev = dev;
    d->fun = fun;

    /* identification, read from configuration space */
    d->header = PciGetDeviceType(bus, dev, fun);
    d->class = PciGetClassCode(bus, dev, fun);
    d->subclass = PciGetSubClassCode(bus, dev, fun);
    d->proIF = PciGetProIF(bus, dev, fun);
    d->vendorID = PciGetVendorID(bus, dev, fun);
    d->deviceID = PciGetDeviceID(bus, dev, fun);
    d->revisionID = PciGetRevisionID(bus, dev, fun);
    d->irq = PciGetIrq(bus, dev, fun);
}

// Probe the Base Address Registers of bus:dev.fun and record them in device->bar[].
//
// Each BAR is sized with the usual sequence:
//   1. save the BAR,
//   2. write all ones,
//   3. read back the size mask,
//   4. restore the saved value to the same register.
// Slots that are not probed, are unimplemented, or hold the upper half of a
// 64-bit BAR keep flags = PCI_BAR_INVALID with base 0 and len 0.
//
// Only header type 0x00 (six BARs) and 0x01 (PCI-to-PCI bridge, two BARs)
// have BARs starting at PCI_BAR0. On a bridge, the following registers hold
// bus numbers and forwarding windows, which must not be overwritten with all
// ones. Other header layouts are not probed.
static void PciBarInit(pci_dev_t *device, uint8_t bus, uint8_t dev, uint8_t fun)
{
    int bar, bar_count;
    uint8_t reg;
    uint32_t data, temp;
    uint8_t type;
    uint32_t base, len;

    switch (PciGetDeviceType(bus, dev, fun))
    {
    case 0x00:
        bar_count = PCI_BAR_MAX;
        break;
    case 0x01:
        bar_count = 2;
        break;
    default:
        bar_count = 0; // e.g. CardBus: different register layout, leave alone
        break;
    }
    if (bar_count > PCI_BAR_MAX)
        bar_count = PCI_BAR_MAX;

    // start from a clean state so data from a previously freed slot never survives
    for (bar = 0; bar < PCI_BAR_MAX; bar++)
    {
        device->bar[bar].base = 0;
        device->bar[bar].len = 0;
        device->bar[bar].flags = PCI_BAR_INVALID;
    }

    for (bar = 0; bar < bar_count; bar++)
    {
        reg = (uint8_t)(PCI_BAR0 + bar * PCI_BAR_SIZE);

        data = PciConfigReadD(bus, dev, fun, reg);
        PciConfigWriteD(bus, dev, fun, reg, 0xffffffff);
        temp = PciConfigReadD(bus, dev, fun, reg);
        PciConfigWriteD(bus, dev, fun, reg, data); // restore the register that was probed

        type = data & 0x1; // bit 0: 0 = memory space, 1 = I/O space
        device->bar[bar].type = type;

        if (temp == 0)
            continue; // no writable bits: BAR not implemented

        if (!type)
        {
            base = data & PCI_MEMMAP_MASK;
            len = ~(temp & PCI_MEMMAP_MASK) + 1;
        }
        else
        {
            base = data & PCI_IOMAP_MASK;
            // I/O BARs may hardwire bits 31:16 to 0. Treat those bits as set
            // so the computed size is not inflated to nearly 4 GiB.
            if (!(temp & 0xFFFF0000u))
                temp |= 0xFFFF0000u;
            len = ~(temp & PCI_IOMAP_MASK) + 1;
        }

        device->bar[bar].base = base;
        device->bar[bar].len = len;
        device->bar[bar].flags = PCI_BAR_AVAILABLE;

        // A memory BAR whose type bits [2:1] are 10b is 64-bit. The next slot
        // holds the upper 32 address bits, not a BAR of its own, so leave that
        // slot invalid and skip it.
        if (!type && (data & 0x6) == 0x4)
            bar++;
    }
}

pci_dev_t *PciGetDevice(uint32_t vendor_id, uint32_t devide_id)
{
    int i;
    pci_dev_t *device;
    for (i = 0; i < PCI_DEV_MAX; i++)
    {
        device = &pci.device[i];

        if (device->flags & PCI_DEVICE_USING && device->vendorID == vendor_id && device->deviceID == devide_id)
        {
            return device;
        }
    }
    return NULL;
}

pci_dev_t *PciGetDeviceByClass(uint32_t class, uint32_t subclass)
{
    int i;
    pci_dev_t *device;

    for (i = 0; i < PCI_DEV_MAX; i++)
    {
        device = &pci.device[i];
        if (device->flags & PCI_DEVICE_USING && device->class == class && device->subclass == subclass)
        {
            return device;
        }
    }
    return NULL;
}

pci_dev_t *PciGetDeviceByClassAndProIF(uint32_t class, uint32_t subclass, uint32_t proIF)
{
    int i;
    pci_dev_t *device;

    for (i = 0; i < PCI_DEV_MAX; i++)
    {
        device = &pci.device[i];
        if (device->flags & PCI_DEVICE_USING && device->class == class && device->subclass == subclass && device->proIF == proIF)
        {
            return device;
        }
    }
    return NULL;
}

// Read the configuration dword at offset `reg` of `device`.
// `reg` is narrowed to the 8-bit offset parameter of PciConfigReadD().
uint32_t PciDeviceRead(pci_dev_t *device, uint32_t reg)
{
    const uint32_t value = PciConfigReadD(device->bus, device->dev, device->fun, reg);

    return value;
}

// Write `data` to the configuration dword at offset `reg` of `device`.
// `reg` is narrowed to the 8-bit offset parameter of PciConfigWriteD().
void PciDeviceWrite(pci_dev_t *device, uint32_t reg, uint32_t data)
{
    pci_dev_t *d = device;

    PciConfigWriteD(d->bus, d->dev, d->fun, reg, data);
}

// Index of the first BAR of `device` that was probed successfully
// (flags == PCI_BAR_AVAILABLE) and whose type equals `type`. Returns -1 if
// there is none. Skipping invalid slots keeps unimplemented BARs (base 0) from
// shadowing a real BAR further down the list.
static int PciDeviceFindBar(pci_dev_t *device, int type)
{
    int i;

    for (i = 0; i < PCI_BAR_MAX; i++)
    {
        if (device->bar[i].flags == PCI_BAR_AVAILABLE && device->bar[i].type == type)
            return i;
    }
    return -1;
}

// Base address of the first valid I/O-space BAR, or 0 if there is none.
uint32_t PciDeviceGetIoAddr(pci_dev_t *device)
{
    int i = PciDeviceFindBar(device, PCI_BAR_IOMAP);

    return (i < 0) ? 0 : device->bar[i].base;
}

// IRQ value recorded for the device by PciDeviceInit().
uint32_t PciDeviceGetIrq(pci_dev_t *device)
{
    return device->irq;
}

// Length of the first valid memory-space BAR, or 0 if there is none.
// Refers to the same BAR as PciDeviceGetMemAddr().
uint32_t PciDeviceGetMemLen(pci_dev_t *device)
{
    int i = PciDeviceFindBar(device, PCI_BAR_MEMMAP);

    return (i < 0) ? 0 : device->bar[i].len;
}

// Base address of the first valid memory-space BAR, or 0 if there is none.
uint32_t PciDeviceGetMemAddr(pci_dev_t *device)
{
    int i = PciDeviceFindBar(device, PCI_BAR_MEMMAP);

    return (i < 0) ? 0 : device->bar[i].base;
}

// Set the bus-master bit (PCI_CMD_MASTER) in the command register of `device`,
// so the device may initiate bus transactions, and log it.
//
// PCI_STATUS_COMMAND is read and written as one dword. The upper 16 bits are
// the status register, whose error bits are write-1-to-clear. Writing back the
// value just read would silently clear any pending error status, so the status
// half is written as 0, which leaves it unchanged.
void PciEnableBusMaster(pci_dev_t *device)
{
    uint32_t val = PciConfigReadD(device->bus, device->dev, device->fun, PCI_STATUS_COMMAND);

    val &= 0x0000FFFF; // keep the command word only
    val |= PCI_CMD_MASTER;
    PciConfigWriteD(device->bus, device->dev, device->fun, PCI_STATUS_COMMAND, val);

    KPrint("[pci] pci device bus:%d dev:%d func:%d enable bus master!\n", device->bus, device->dev, device->fun);
}

// Register bus:dev.fun in the device table.
// Takes a free slot, fills its identification and BAR data, then marks it in
// use and bumps the device counter. Does nothing when the table is full.
void PciDeviceAdd(uint8_t bus, uint8_t dev, uint8_t fun)
{
    pci_dev_t *device = PciDeviceAlloc();

    if (!device)
        return; // no free slot

    PciDeviceInit(device, bus, dev, fun);
    PciBarInit(device, bus, dev, fun);
    device->flags = PCI_DEVICE_USING;

    ++dev_count;
}

// Public wrapper: release the table slot of `device` (NULL is ignored).
// The device counter is not changed.
void PciDeviceRemove(pci_dev_t *device)
{
    PciDeviceDel(device);
}

// Record bus number `bus` in the bus table, with its father/child links,
// and bump the bus counter. Does nothing when the bus table is full.
// `isbright` is accepted for interface compatibility but not stored.
void PciBusAdd(int father, int child, int isbright, uint8_t bus)
{
    pci_bus_t *entry = PciBusAlloc();

    (void)isbright;
    if (entry == NULL)
        return;

    entry->bus = bus;
    entry->father = father;
    entry->child = child;
    entry->flags = PCI_BUS_USING;

    bus_count++;
}

// Public wrapper: release the table slot of `bus` (NULL is ignored).
// The bus counter is not changed.
void PciBusRemove(pci_bus_t *bus)
{
    PciBusFree(bus);
}

void PciInit()
{
    int i, j;

    // init bus and device info alloc status
    for (i = 0; i < PCI_BUS_MAX; i++)
    {
        pci.bus[i].flags = PCI_BUS_INVALID;
    }

    for (i = 0; i < PCI_DEV_MAX; i++)
    {
        pci.device[i].flags = PCI_DEVICE_INVALID;
        for (j = 0; j < PCI_BAR_MAX; j++)
        {
            pci.device[i].bar[j].flags = PCI_BAR_INVALID;
        }
    }

    // start init pci
    pci.flags = PCI_INIT;
    // default pci type to single host controller
    pci.type = PCI_SINGLEHOST;
    // if bus 0,device 0,fun 0 is muitifun device,it is muitihost controller
    if (!PciHostControllerCheck())
    {
        pci.type = PCI_MUITIHOST;
    }
    // scan pci device
    PciScan();
    // store bus and device numbers
    pci.bus_num = bus_count;
    pci.dev_num = dev_count;
    // pci ready
    pci.flags = PCI_READY;
#ifdef DEBUG_PCI
    PciDump();
#endif
}

// Print the bus and device tables (compiled in only with DEBUG_PCI).
// The device loop walks the whole table (PCI_DEVICE_MAX slots, as filled by
// PciDeviceAlloc()), so devices beyond the first PCI_DEV_MAX slots are shown.
void PciDump()
{
#ifdef DEBUG_PCI
    int i, j;
    pci_dev_t *dev;
    pci_bus_t *bus;
    pci_bar_t *bar;

    KPrint("pci controller type: %d  bus num: %d  device num: %d\n", pci.type, pci.bus_num, pci.dev_num);
    // dump bus
    for (i = 0; i < PCI_BUS_MAX; i++)
    {
        bus = &pci.bus[i];
        if (bus->flags != PCI_BUS_USING)
            continue;
        KPrint("pci bus %d info:\n", bus->bus);
        KPrint("    isbright: %d\n", bus->bright);
        KPrint("    father: %d     child: %d\n", bus->father, bus->child);
    }
    // dump device
    for (i = 0; i < PCI_DEVICE_MAX; i++)
    {
        dev = &pci.device[i];
        if (dev->flags != PCI_DEVICE_USING)
            continue;

        KPrint("pci device:\n");
        KPrint("    header: %d\n", dev->header);
        KPrint("    class: %d subclass: %d\n", dev->class, dev->subclass);
        KPrint("    bus: %d dev: %d fun: %d multifun: %d\n", dev->bus, dev->dev, dev->fun, dev->muitifun);
        KPrint("    vendorID: %d    revisionID: %d\n", dev->vendorID, dev->revisionID);
        KPrint("    irq: %d\n", dev->irq);
        KPrint("    pci bar list:\n");
        for (j = 0; j < PCI_BAR_MAX; j++)
        {
            bar = &dev->bar[j];
            if (bar->flags != PCI_BAR_AVAILABLE)
                continue;
            KPrint("    Bar%d info:  type: %d base: %x len: %d\n", j, bar->type, bar->base, bar->len);
        }
    }
#endif
}